{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Network-Blending-ADA-PT.ipynb",
      "private_outputs": true,
      "provenance": [],
      "machine_shape": "hm",
      "mount_file_id": "17KlZS6Ax5HEcJDhW0EVAsrOrXBeTVqxv",
      "authorship_tag": "ABX9TyN0OSZzbXfWmS2HYwtpQSn3",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/dvschultz/stylegan2-ada-pytorch/blob/main/Network_Blending_ADA_PT.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CdkWE1TKUWwB"
      },
      "source": [
        "# Network Blending\n",
        "This demo will show how to combine two separate StyleGAN2-ADA-PyTorch models into one by splitting their weights at a specified layer.\n",
        "\n",
        "This example was created by Derrick Schultz for his Advanced StyleGAN2 class. It’s a simpler version of [Justin Pinkney’s Tensorflow version](https://github.com/justinpinkney/stylegan2/blob/master/blend_models.py).\n",
        "\n",
        "---\n",
        "\n",
        "If you find this notebook useful, consider signing up for my [Patreon](https://www.patreon.com/bustbright) or [YouTube channel](https://www.youtube.com/channel/UCaZuPdmZ380SFUMKHVsv_AA/join). You can also send me a one-time payment on [Venmo](https://venmo.com/Derrick-Schultz).\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aJFXX8WIBeqy"
      },
      "source": [
        "!nvidia-smi -L"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hvwu1LpCOJtz"
      },
      "source": [
        "# Start from /content so re-running this cell does not clone a nested copy;\n",
        "# later cells use absolute /content/stylegan2-ada-pytorch/... paths.\n",
        "%cd /content\n",
        "!git clone https://github.com/dvschultz/stylegan2-ada-pytorch\n",
        "%cd /content/stylegan2-ada-pytorch\n",
        "# %pip installs into the environment of the running kernel.\n",
        "%pip install ninja opensimplex"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O09iXoAPQsHX"
      },
      "source": [
        "## Download two models"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "p9dwvbtWp97y"
      },
      "source": [
        "!wget http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-church-config-f.pkl\n",
        "!wget http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-cat-config-f.pkl"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kRQUzZK2HHSt"
      },
      "source": [
        "!gdown --id 15GpzB-wTwGIZC_Wu0ruaEJi7-giRWOOo -O /content/bone-bone.pkl\n",
        "!wget http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-f.pkl"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FnHs7WBuqXDo"
      },
      "source": [
        "!python legacy.py --source=/content/stylegan2-ada-pytorch/stylegan2-ffhq-config-f.pkl --dest=/content/ffhq-pt.pkl\n",
        "!python legacy.py --source=/content/bone-bone.pkl --dest=/content/bone-bone-pt.pkl"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c_kFpXUAvcTP"
      },
      "source": [
        "## Script Example\n",
        "If you simply want to run the blend as a script, run the cells below."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s_P9TtAsv_xW"
      },
      "source": [
        "!python blend_models.py --help"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "A-S0KjY1vuzj"
      },
      "source": [
        "!python blend_models.py --lower_res_pkl /content/ffhq-pt.pkl --split_res 64 --higher_res_pkl /content/bone-bone-pt.pkl --output_path /content/ffhq-bonebone-split64.pkl"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4uDMVyB0UCiS"
      },
      "source": [
        "## Code example\n",
        "\n",
        "If you want to see under the hood, here’s how this works."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ENNGfqvKS59e"
      },
      "source": [
        "import os\n",
        "import copy\n",
        "import numpy as np\n",
        "import torch\n",
        "import pickle\n",
        "import dnnlib\n",
        "import legacy\n",
        "\n",
        "def extract_conv_names(model, model_res):\n",
        "    \"\"\"Return the name of every parameter of ``model``, in ``named_parameters()`` order.\n",
        "\n",
        "    ``model_res`` is not used; it is kept so existing calls keep working.\n",
        "    \"\"\"\n",
        "    return [name for name, _ in model.named_parameters()]\n",
        "\n",
        "def blend_models(low, high, model_res, resolution, level, blend_width=None):\n",
        "    \"\"\"Blend two generators: coarse layers come from ``low``, the rest from ``high``.\n",
        "\n",
        "    The result is a deep copy of ``low`` into which every parameter of ``high``\n",
        "    is copied, except parameters whose name contains ``mapping`` or that belong\n",
        "    to a ``synthesis.b{res}`` block with ``res`` in 4, 8, ..., ``resolution``.\n",
        "    Neither input model is modified.\n",
        "\n",
        "    Args:\n",
        "        low: model whose mapping parameters and blocks up to and including\n",
        "            ``resolution`` are kept.\n",
        "        high: model supplying every other parameter.\n",
        "        model_res: only forwarded to ``extract_conv_names``, which ignores it.\n",
        "        resolution: highest block resolution kept from ``low``.\n",
        "        level: unused; kept so existing calls keep working.\n",
        "        blend_width: unused; kept so existing calls keep working.\n",
        "\n",
        "    Returns:\n",
        "        The blended copy of ``low``.\n",
        "\n",
        "    Raises:\n",
        "        AssertionError: if the two models do not expose identical parameter names.\n",
        "    \"\"\"\n",
        "    # Block resolutions retained from `low`: 4, 8, ..., resolution.\n",
        "    resolutions = [4 * 2**x for x in range(int(np.log2(resolution) - 1))]\n",
        "    print(resolutions)\n",
        "\n",
        "    low_names = extract_conv_names(low, model_res)\n",
        "    high_names = extract_conv_names(high, model_res)\n",
        "\n",
        "    # Compare the full lists: zip() stops at the shorter one, which would let\n",
        "    # models with a different number of parameters slip through the check.\n",
        "    assert low_names == high_names, 'low and high models must have identical parameter names'\n",
        "\n",
        "    # The trailing dot keeps e.g. `synthesis.b4.` from also matching `synthesis.b4096.`.\n",
        "    kept_prefixes = tuple(f'synthesis.b{res}.' for res in resolutions)\n",
        "\n",
        "    # Start with the lower model and overwrite the weights above the split.\n",
        "    model_out = copy.deepcopy(low)\n",
        "    dest_params = dict(model_out.named_parameters())\n",
        "\n",
        "    # NOTE(review): only parameters are transferred; buffers of `high` are not\n",
        "    # copied, matching the previous behaviour -- confirm this is intended.\n",
        "    with torch.no_grad():\n",
        "        for name, param in high.named_parameters():\n",
        "            if 'mapping' in name or any(prefix in name for prefix in kept_prefixes):\n",
        "                continue\n",
        "            dest_params[name].copy_(param)\n",
        "\n",
        "    return model_out"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ovSNBUBxOeXt"
      },
      "source": [
        "# Converted checkpoints written by the `legacy.py` cell above; swap in your own\n",
        "# StyleGAN2-ADA-PyTorch pickles to blend different models.\n",
        "lo_res_pkl = '/content/ffhq-pt.pkl'\n",
        "hi_res_pkl = '/content/bone-bone-pt.pkl'\n",
        "model_res = 1024\n",
        "level = 0\n",
        "blend_width=None\n",
        "\n",
        "G_kwargs = dnnlib.EasyDict()\n",
        "\n",
        "# Keep G, D and G_ema of the low-res model: G and D are saved unchanged next to the blended G_ema.\n",
        "with dnnlib.util.open_url(lo_res_pkl) as f:\n",
        "    lo = legacy.load_network_pkl(f, custom=False, **G_kwargs) # type: ignore\n",
        "    lo_G, lo_D, lo_G_ema = lo['G'], lo['D'], lo['G_ema']\n",
        "\n",
        "# Only G_ema is taken from the high-res model.\n",
        "with dnnlib.util.open_url(hi_res_pkl) as f:\n",
        "    hi = legacy.load_network_pkl(f, custom=False, **G_kwargs)['G_ema'] # type: ignore\n",
        "\n",
        "# Write one blended model per split resolution.\n",
        "rezes = [8,16,32,64,128]\n",
        "for r in rezes:\n",
        "    model_out = blend_models(lo_G_ema, hi, model_res, r, level, blend_width=blend_width)\n",
        "\n",
        "    # Save a new pkl file; the name must match what the generate cell below loads.\n",
        "    out = f'/content/blend-frea-ladiestransfer-{r}.pkl'\n",
        "    data = {'G': lo_G, 'D': lo_D, 'G_ema': model_out}\n",
        "    with open(out, 'wb') as f:\n",
        "        pickle.dump(data, f)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M0XUCmyJUJ66"
      },
      "source": [
        "## Test Generating Images With Your New Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qEe21is5r1qV"
      },
      "source": [
        "for r in rezes:\n",
        "    !python generate.py --outdir=/content/out/blended-frea-ladiestransfer2-{r}/ --trunc=0.6 --seeds=0-24 --network=/content/blend-frea-ladiestransfer-{r}.pkl"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "B8WmzAmLsL8I"
      },
      "source": [
        "!zip -r transferred-blends_r2.zip /content/out"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OF258_TAuFV1"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}